import json
import numpy as np
import tensorflow.keras as keras
# from keras import models, utils
import music21 as m21
from preprocess import SEQUENCE_LENGTH, MAPPING_PATH

MODEL_PATH = 'model2.h5'

class MelodyGenerator:

    def __init__(self, model_path=MODEL_PATH):

        self.model_path = model_path
        self.model = keras.models.load_model(model_path)

        with open(MAPPING_PATH, 'r') as fp:
            self._mappings = json.load(fp)

        self._start_symbols = ['/'] * SEQUENCE_LENGTH

    def generate_melody(self, seed, num_steps, max_sequence_length, temperature):
        # create seed with start symbols
        seed = seed.split()
        melody = seed
        seed = self._start_symbols + seed
        # map seed to int
        seed = [self._mappings[symbol] for symbol in seed]
        for _ in range(num_steps):
            # limit the seed to max_sequence_length
            seed = seed[-max_sequence_length:]
            # one-hot encode the seed
            onehot_seed = keras.utils.to_categorical(seed, num_classes=len(self._mappings))
            onehot_seed = onehot_seed[np.newaxis, ...]
            # make a prediction
            probabilities = self.model.predict(onehot_seed)[0]
            output_int = self._sample_with_temperature(probabilities, temperature)
            # update seed
            seed.append(output_int)
            # map int to our encoding
            output_symbol = [k for k, v in self._mappings.items() if v == output_int][0]
            # check whether we're at the end of a melody
            if output_symbol == '/':
                break
            # update the melody
            melody.append(output_symbol)
        return melody

    def _sample_with_temperature(self, probabilities, temperature):
        # temperature -> infinity   nearly random
        # temperature -> 0     determined
        # temperature = 1   nothing changes
        predictions = np.log(probabilities) / temperature
        probabilities = np.exp(predictions) / np.sum(np.exp(predictions))

        choices = range(len(probabilities))  # [0, 1, 2, 3, 4]
        index = np.random.choice(choices, p=probabilities)

        return index

    def save_melody(self, melody, step_duration=0.25, format='midi', file_name='mel.midi'):
        # create a music21 stream
        stream = m21.stream.Stream()
        # parse all the symbols in the melody and create note/rest objects
        # 60 _ _ _ r _ 62 _
        start_symbol = None
        step_counter = 1
        for i, symbol in enumerate(melody):
            # handle case in which we have a note/rest
            if symbol != '_' or i + 1 == len(melody):
                # ensure we're dealing with note/rest beyond the first one
                if start_symbol is not None:
                    quarter_length_duration = step_duration * step_counter  # 0.25 * 4
                    # handle rest
                    if start_symbol == 'r':
                        m21_event = m21.note.Rest(quarterLength=quarter_length_duration)
                    # handle note
                    else:
                        m21_event = m21.note.Note(int(start_symbol), quarterLength=quarter_length_duration)
                    stream.append(m21_event)
                    # reset the step counter
                    step_counter = 1
                start_symbol = symbol
            # handle case in which we have a prolongation sign '_'
            else:
                step_counter += 1
        # write the m21 stream to a midi file
        stream.write(format, file_name)


if __name__ == '__main__':
    mg = MelodyGenerator()
    seed = '55 _ _ _ 60 _ _ _ 55 _ _ _ 55 _'
    seed2 = '67 _ 67 _ 67 _ _ 65 64 _ 64 _ 64 _ _'
    seed3 = '55 _ 50 _ 64 _ _ 65 50 _ 58 _ 64 _ _'
    seed4_intro = '68 78 _ _ 73 _ 73 _ 71 _ _ _ 55 _ _ 55 55 _ _ 50 _ _ 58 _ _'
    melody = mg.generate_melody(seed4_intro, 200, SEQUENCE_LENGTH, 2.6)
    print(melody)
    mg.save_melody(melody, file_name='generated_musics/model2_13.midi')